package org.luosl.webmagicx.pipeline

import java.sql.{Connection, DriverManager, ResultSet}
import java.util

import org.apache.commons.dbutils.QueryRunner
import org.eclipse.jetty.util.ConcurrentHashSet
import org.luosl.webmagicx.Utils.SqlUtils
import org.luosl.webmagicx.conf.{Props, SpiderConf}
import org.luosl.webmagicx.pipeline.component.{Distinct, HashSetDistinct}
import us.codecraft.webmagic.{ResultItems, Task}

import scala.collection.JavaConverters._

/**
  * Created by luosl on 2017/12/11.
  */
class SimpleJdbcPipeline(sc:SpiderConf, props:Props) extends BasePipeline(sc, props) {

  /**
    * Target table name (the `tableName` property, formatted by SqlUtils.paramFormat).
    */
  private val tableName:String = SqlUtils.paramFormat(props.prop("tableName"))

  /**
    * The INSERT statement plus the column / field lists it binds.
    *
    * `fields` are the ResultItems keys to persist: `_url` plus every spider field when
    * `needSaveFields` is `*` (the default), otherwise the listed names (separated by an
    * ASCII or a full-width comma). `dbColumns(i)` is the formatted column for `fields(i)`,
    * and the i-th `?` placeholder of `insertSql` binds `fields(i)`.
    */
  private val (insertSql, dbColumns, fields):(String, List[String], List[String]) = {
    val needSaveFields:String = props.prop("needSaveFields", "*").trim
    val requested:List[String] =
      if (needSaveFields == "*") "_url" :: sc.fields.map(_.name).toList
      else needSaveFields.split(",|，").map(_.trim).filter(_.nonEmpty).toList
    // Sort the whole list (both branches) and drop duplicates: a repeated name would make
    // both the CREATE TABLE and the INSERT statement invalid.
    val fieldList:List[String] = requested.distinct.sorted
    require(fieldList.nonEmpty, s"[needSaveFields] 未配置任何字段,请检查配置文件:${sc.path}")
    val columns:List[String] = fieldList.map(SqlUtils.paramFormat)
    val placeholders:String = columns.map(_ => "?").mkString(",")
    val sql:String = s"insert into $tableName(${columns.mkString(",")}) values($placeholders)"
    (sql, columns, fieldList)
  }

  /** One connection per worker thread, created lazily by [[conn]]. */
  private val threadLocal:ThreadLocal[Connection] = new ThreadLocal[Connection]

  /**
    * Every connection handed out by [[conn]], so [[close]] can release them.
    * Synchronized because [[conn]] is called from several worker threads.
    */
  private val allConn:util.List[Connection] = util.Collections.synchronizedList(new util.ArrayList[Connection]())

  // Create / truncate the table before the distinct cache is built, so that a cache which
  // loads eagerly sees an existing table that has already been truncated in `override` mode.
  checkOrCreateTable()

  /**
    * Optional de-duplication, enabled by the `distinctField` property. Its cache is filled
    * from the values already stored in that column, and each item is keyed by
    * `resultItems.get(distinctField)`.
    */
  private val distinctOpt:Option[Distinct] = {
    props.propOption("distinctField").map{ disField=>
      val loadCacheOp:ConcurrentHashSet[Any]=>Unit = (set:ConcurrentHashSet[Any])=> {
        val sql = s"select ${SqlUtils.paramFormat(disField)} from $tableName"
        val sqlOp:ResultSet=>Unit = (rs:ResultSet) => set.add(rs.getObject(1))
        SqlUtils.executeQuery(sql,conn(),sqlOp,autoClose = false)
      }
      val distValOp:ResultItems => Any = (ris:ResultItems) => ris.get(disField).asInstanceOf[Any]
      new HashSetDistinct(loadCacheOp,distValOp)
    }
  }

  /**
    * Opens a new JDBC connection from the `driver`, `url`, `user` and `password` properties.
    * The URL is first passed through [[jdbcUrlFormat]].
    * @return a new, caller-owned connection
    */
  private def createConn():Connection = {
    val user = props.prop("user")
    val password = props.prop("password")
    val jdbcUrl = props.prop("url")
    Class.forName(props.prop("driver"))
    DriverManager.getConnection(jdbcUrlFormat(jdbcUrl), user, password)
  }

  /**
    * Returns the connection bound to the calling thread, creating it on first use.
    * A cached connection that has since been closed is replaced with a fresh one.
    * Connections created here are registered for [[close]].
    * @return an open connection owned by this pipeline
    */
  def conn():Connection = {
    val cached:Connection = threadLocal.get()
    if (cached != null && !cached.isClosed) cached
    else {
      // Forget a stale (closed) connection so the thread is never handed a dead one.
      if (cached != null) allConn.remove(cached)
      val fresh:Connection = createConn()
      threadLocal.set(fresh)
      allConn.add(fresh)
      fresh
    }
  }

  /**
    * Ensures the target table exists, then applies the configured `model`.
    *
    * Uses a dedicated connection that is closed afterwards. It deliberately bypasses
    * [[conn]]: caching a connection in the thread-local only to close it here would leave
    * the constructing thread holding a closed connection.
    *
    * If the probe query fails, the table is assumed to be missing and is created with every
    * column typed as `text`. A failure to create it is logged, not rethrown.
    *
    * @throws RuntimeException if `model` is neither `append` nor `override`
    */
  private def checkOrCreateTable(): Unit ={
    val con:Connection = createConn()
    val runner = new QueryRunner()
    try{
      try {
        runner.execute(con,s"select count(*) from $tableName")
        logInfo(s"数据库表:$tableName 已经存在...")
      }catch {
        case _:Exception =>
          logInfo(s"开始创建数据库表:$tableName ...")
          val createTableSql:String =
            s"""
               |CREATE TABLE $tableName (
               |${dbColumns.map(cln=> s"$cln text").mkString(",")}
               |)
            """.stripMargin
          try {
            runner.execute(con, createTableSql)
          }catch {
            case e:Exception =>
              logError(s"执行建表语句:$createTableSql 失败!",e)
          }
      }
      props.prop("model") match {
        case "append" =>
          logInfo("当前Model为:append")
        case "override" =>
          logInfo("当前Model为:override")
          logInfo("开始执行 truncate 操作...")
          runner.execute(con,s"truncate table $tableName")
        case other =>
          val msg:String = s"无效的Model:$other,SimpleJdbcPipeline的[model]属性只支持[append,override],请检查配置文件:${sc.path}"
          throw new RuntimeException(msg)
      }
    }finally con.close()
  }

  /**
    * Adds `autoReconnect=true` to the JDBC URL unless it already mentions `autoReconnect`.
    * @param url url
    * @return the URL, with the extra query parameter appended when needed
    */
  private def jdbcUrlFormat(url:String):String =
    if (url.contains("autoReconnect")) url
    else url + (if (url.contains("?")) "&" else "?") + "autoReconnect=true"

  /**
    * Runs one INSERT, binding `fields(i)` from `paramMap` to placeholder i + 1.
    *
    * Every parameter is resolved before a statement is prepared, so a missing field fails
    * fast without opening a statement. The statement is always closed. Failures propagate
    * so the caller can report them.
    *
    * @param insertSql insert SQL template
    * @param paramMap  values keyed by field name
    * @throws NoSuchElementException if `paramMap` lacks one of the configured fields
    * @throws java.sql.SQLException  if preparing or executing the statement fails
    */
  private def executeInsert(insertSql:String,paramMap: Map[String,Any]): Unit ={
    val params:List[(Int, Any)] = fields.zipWithIndex.map{ case (field, idx) =>
      val value:Any = paramMap.getOrElse(field,
        throw new NoSuchElementException(s"无效的参数[$field] 在[$insertSql]"))
      (idx + 1, value)
    }
    val ps = conn().prepareStatement(insertSql)
    try {
      params.foreach{ case (idx, value) => ps.setObject(idx, value) }
      ps.executeUpdate()
    } finally {
      ps.close()
    }
  }

  /**
    * Persists one item.
    *
    * If a distinct field is configured, an item whose value was already seen triggers
    * `onSkip`. Otherwise the item is inserted: `onSuccess` on success, `onError` on failure
    * (never both).
    */
  override def process(resultItems: ResultItems, task: Task): Unit = {
    /**
      * Inserts the item. On failure it logs, calls `onError` and returns false.
      * @return whether the row was written
      */
    def save():Boolean = {
      val paramMap:Map[String, Any] = resultItems.getAll.asScala.map(t=>t._1->t._2.asInstanceOf[Any]).toMap
      // Used only for logging; tolerate items without a `_url` entry instead of throwing.
      val url:Any = paramMap.getOrElse("_url", "")
      try {
        executeInsert(insertSql,paramMap)
        logInfo(s"保存[url=$url]成功.....")
        true
      }catch {
        case e:Exception =>
          logError(s"保存[insertSql=$insertSql,params=${paramMap.mkString(",")},url=$url]失败.....",e)
          onError(resultItems, task)
          false
      }
    }
    // NOTE(review): the distinct key is recorded before the insert, so an item whose save
    // fails is still treated as seen afterwards.
    val isNew:Boolean = distinctOpt.forall(_.isUniqueAndAdd(resultItems))
    if (!isNew) onSkip(resultItems, task)
    else if (save()) onSuccess(resultItems, task)
  }

  /**
    * Closes every connection handed out by [[conn]]. A failure on one connection is logged
    * and does not stop the others from being closed.
    */
  override def close(): Unit ={
    // Snapshot and clear under the list's own lock, as synchronizedList requires for iteration.
    val opened:List[Connection] = allConn.synchronized {
      val snapshot = allConn.asScala.toList
      allConn.clear()
      snapshot
    }
    opened.foreach{ c =>
      try {
        if(!c.isClosed) c.close()
      } catch {
        case e:Exception => logError("关闭数据库连接失败!", e)
      }
    }
  }

}
